#include "PhysXTerrainWrapper.h"
#include "CoordinateMapping.h"
#include <NxHeightFieldShape.h>
#include <NxHeightFieldShapeDesc.h>
#include <NxHeightField.h>
#include <NxHeightFieldDesc.h>
#include <NxMat34.h>
#include <NxHeightFieldSample.h> 

// Builds an integer face/vertex view of a PhysX heightfield shape.
//
// heightFieldShape   Source shape. Its global pose, row/column/height scales and
//                    hole material are read via getGlobalPose()/saveToDesc(); its
//                    samples are copied out with NxHeightField::saveCells().
// coordinateMapping  applyTo() is called on every stored coordinate triple after
//                    the world-space transform has been applied.
// scale              Multiplies both the 3x3 part and the translation of the shape's
//                    global pose before sample positions are transformed.
//
// Per-sample data is stored in row-major order (sample index = row * nbColumns + column):
//   _pointCoords        three coordinates per sample; world positions are converted to
//                       int32_t with static_cast (i.e. truncated toward zero).
//   _tessellationFlags  the sample's tessFlag.
// Every sample reserves two consecutive face indices (sampleIndex * 2 and sampleIndex * 2 + 1),
// including samples on the last row/column. This keeps faceIndex / 2 equal to the index of the
// cell's first corner sample, which vertexIndex() relies on. Only samples that are not on the
// last row or column start a cell, and only triangles whose material differs from the shape's
// hole material are appended to _holesRemap.
cPhysXTerrainWrapper::cPhysXTerrainWrapper(const NxHeightFieldShape& heightFieldShape, const cCoordinateMapping& coordinateMapping, float scale)
{
    NxMat34 transform = heightFieldShape.getGlobalPose();
    transform.M *= scale;
    transform.t *= scale;

    NxHeightFieldShapeDesc heightFieldShapeDesc;
    heightFieldShape.saveToDesc(heightFieldShapeDesc);

    const NxHeightField& heightField = heightFieldShape.getHeightField();

    NxHeightFieldDesc heightFieldDesc;
    heightField.saveToDesc(heightFieldDesc);
    // The per-sample decoding below reads NxHeightFieldSample fields, which assumes this format.
    assert(heightFieldDesc.format == NX_HF_S16_TM);

    _stride = heightFieldDesc.nbColumns;
    _strides = heightFieldDesc.nbRows;
    const int32_t points = _stride * _strides;
    if(points == 0)
    {
        // Nothing to convert; also avoids taking &samplesBuffer[0] of an empty vector below.
        return;
    }

    const NxU32 samplesBufferSize = points * heightFieldDesc.sampleStride;
    std::vector<char> samplesBuffer(samplesBufferSize);
    {
        const NxU32 written = heightField.saveCells(&samplesBuffer[0], samplesBufferSize);
        assert(written == samplesBufferSize);
        (void)written; // only inspected by the assert; silences unused-variable warnings in release
    }

    _pointCoords.reserve(points * 3);
    _tessellationFlags.reserve(points);

    // This code places rows along local x, sample heights along local y and columns along local z.
    // Each axis uses its own scale from the shape descriptor.
    const NxReal rowScale = heightFieldShapeDesc.rowScale;
    const NxReal columnScale = heightFieldShapeDesc.columnScale;
    const NxReal heightScale = heightFieldShapeDesc.heightScale;

    const char* bytePtr = &samplesBuffer[0];
    int32_t faceIndex = 0;
    for(int32_t row = 0; row != _strides; ++row)
    {
        NxVec3 localPoint;
        localPoint.x = static_cast<NxReal>(row) * rowScale;
        for(int32_t column = 0; column != _stride; ++column)
        {
            const NxHeightFieldSample& sample = *reinterpret_cast<const NxHeightFieldSample*>(bytePtr);
            localPoint.y = static_cast<NxReal>(sample.height) * heightScale;
            // Column spacing comes from columnScale (not rowScale), so non-square cells keep their shape.
            localPoint.z = static_cast<NxReal>(column) * columnScale;
            NxVec3 worldPoint;
            transform.multiply(localPoint, worldPoint);
            _pointCoords.push_back(static_cast<int32_t>(worldPoint.x));
            _pointCoords.push_back(static_cast<int32_t>(worldPoint.y));
            _pointCoords.push_back(static_cast<int32_t>(worldPoint.z));
            _tessellationFlags.push_back(sample.tessFlag);

            // Only samples that are not on the last row or column are the first corner of a cell.
            const bool startsCell = (row + 1 != _strides && column + 1 != _stride);
            if(startsCell)
            {
                // materialIndex0 belongs to the cell's first triangle, materialIndex1 to its second.
                if(sample.materialIndex0 != heightFieldShapeDesc.holeMaterial)
                {
                    _holesRemap.push_back(faceIndex);
                }
                if(sample.materialIndex1 != heightFieldShapeDesc.holeMaterial)
                {
                    _holesRemap.push_back(faceIndex + 1);
                }
            }
            // Advanced for every sample (see header comment) so faceIndex / 2 == sample index.
            faceIndex += 2;
            bytePtr += heightFieldDesc.sampleStride;
        }
    }

    for(int32_t i = 0; i != points; ++i)
    {
        coordinateMapping.applyTo(&_pointCoords[i * 3]);
    }
}


// iFaceVertexMesh interface

// Number of faces exposed through the mesh interface: one per entry in _holesRemap,
// i.e. one per heightfield triangle that was not filtered out as a hole.
int32_t
cPhysXTerrainWrapper::faces() const
{
    const int32_t nonHoleFaceCount = static_cast<int32_t>(_holesRemap.size());
    return nonHoleFaceCount;
}
// Number of vertices: one per heightfield sample (columns * rows).
// Debug builds check that three coordinates were stored for each of them.
int32_t
cPhysXTerrainWrapper::vertices() const
{
    const int32_t sampleCount = _stride * _strides;
    assert(sampleCount * 3 == static_cast<int32_t>(_pointCoords.size()));
    return sampleCount;
}
// Returns the vertex (sample) index of corner `vertexInFace` (expected 0..2) of exposed face `face`.
//
// `face` is first mapped through _holesRemap to a raw face index. Each cell owns two raw
// indices: (cornerSample * 2) and (cornerSample * 2 + 1), where cornerSample is the index of
// the cell's first corner sample. Corners within a cell are numbered:
//   0 = cornerSample, 1 = next column, 2 = next row and next column, 3 = next row.
// The cell's tessellation flag selects which diagonal splits it into two triangles.
// NOTE(review): confirm this diagonal choice against the PhysX NxHeightFieldSample::tessFlag docs.
int32_t
cPhysXTerrainWrapper::vertexIndex(int32_t face, int32_t vertexInFace) const
{
    assert(face >= 0 && face < static_cast<int32_t>(_holesRemap.size()));
    const int32_t rawFace = _holesRemap[face];
    const int32_t cornerSample = rawFace / 2;
    const int32_t triangleInCell = (rawFace & 1);

    // Corner lists for both triangles of a cell, three corners per triangle.
    static const int32_t cornersWhenFlagSet[] = {0, 3, 1, 1, 3, 2};
    static const int32_t cornersWhenFlagClear[] = {0, 2, 1, 0, 3, 2};
    const int32_t* const cornerTable =
        _tessellationFlags[cornerSample] ? cornersWhenFlagSet : cornersWhenFlagClear;
    const int32_t corner = cornerTable[triangleInCell * 3 + vertexInFace];

    // Offset from the cell's first corner sample to the requested corner.
    int32_t offset = 0;
    if(corner == 1)
    {
        offset = 1;
    }
    else if(corner == 2)
    {
        offset = 1 + _stride;
    }
    else if(corner == 3)
    {
        offset = _stride;
    }
    return cornerSample + offset;
}
// First of the three coordinates stored for `vertex` (after the coordinate mapping
// applied in the constructor).
int32_t
cPhysXTerrainWrapper::vertexX(int32_t vertex) const
{
    const int32_t tripleStart = 3 * vertex;
    return _pointCoords[tripleStart];
}
// Second of the three coordinates stored for `vertex` (after the coordinate mapping
// applied in the constructor).
int32_t
cPhysXTerrainWrapper::vertexY(int32_t vertex) const
{
    const int32_t tripleStart = 3 * vertex;
    return _pointCoords[tripleStart + 1];
}
// Third of the three coordinates stored for `vertex`, converted to float as the
// interface requires (the stored value itself is an integer).
float
cPhysXTerrainWrapper::vertexZ(int32_t vertex) const
{
    const int32_t storedZ = _pointCoords[3 * vertex + 2];
    return static_cast<float>(storedZ);
}
// Per-face attribute query. Every face reports section ID 0 (the first terrain layer).
// All other attributes report -1 for every face (presumably "not specified" in the
// iFaceVertexMesh convention; verify against the PathEngine docs).
int32_t
cPhysXTerrainWrapper::faceAttribute(int32_t face, int32_t attributeIndex) const
{
    const bool isSectionIdQuery = (attributeIndex == PE_FaceAttribute_SectionID);
    return isSectionIdQuery ? 0 : -1;
}
